{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 343
    },
    "colab_type": "code",
    "id": "DcM6MNPX_ZJS",
    "outputId": "1ccea931-f78b-45f0-b6cf-c018dcdaa860"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fatal: destination path 'TensorNetwork' already exists and is not an empty directory.\n",
      "Processing ./TensorNetwork\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.6/dist-packages (from tensornetwork==0.4.0) (1.18.4)\n",
      "Requirement already satisfied: graphviz>=0.11.1 in /usr/local/lib/python3.6/dist-packages (from tensornetwork==0.4.0) (0.14)\n",
      "Requirement already satisfied: opt_einsum>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from tensornetwork==0.4.0) (3.2.1)\n",
      "Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.6/dist-packages (from tensornetwork==0.4.0) (2.10.0)\n",
      "Requirement already satisfied: scipy>=1.1 in /usr/local/lib/python3.6/dist-packages (from tensornetwork==0.4.0) (1.4.1)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from h5py>=2.9.0->tensornetwork==0.4.0) (1.12.0)\n",
      "Building wheels for collected packages: tensornetwork\n",
      "  Building wheel for tensornetwork (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for tensornetwork: filename=tensornetwork-0.4.0-cp36-none-any.whl size=263053 sha256=de2f64c327c1074c9eafb9d7c3d7aa3bfd4dc076a640c7aadb912b2e566ab6e4\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-5_ld1sxs/wheels/f0/25/c0/f94fcb8f0e82252f2ee53dc257fb4b039cc2184b321375ed18\n",
      "Successfully built tensornetwork\n",
      "Installing collected packages: tensornetwork\n",
      "  Found existing installation: tensornetwork 0.4.0\n",
      "    Uninstalling tensornetwork-0.4.0:\n",
      "      Successfully uninstalled tensornetwork-0.4.0\n",
      "Successfully installed tensornetwork-0.4.0\n"
     ]
    }
   ],
   "source": [
    "# Clone only if the checkout is missing, so re-running this cell does not\n",
    "# fail with git's 'destination path already exists' error.\n",
    "!test -d TensorNetwork || git clone https://github.com/google/TensorNetwork.git\n",
    "!pip install ./TensorNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "o96UaYqz_auC"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensornetwork as tn\n",
    "import numpy as np\n",
    "from tensornetwork.tn_keras.dense import DenseDecomp\n",
    "from tensornetwork.tn_keras.mpo import DenseMPO\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.models import Sequential\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FDRj9bX1AjO3"
   },
   "outputs": [],
   "source": [
    "def dummy_data(input_dim, num_samples=100, seed=42):\n",
    "    \"\"\"Create a reproducible toy binary-classification dataset.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of feature columns per sample.\n",
    "        num_samples: Number of rows to generate (default 100).\n",
    "        seed: Value passed to np.random.seed before sampling (default 42).\n",
    "\n",
    "    Returns:\n",
    "        A (data, labels) tuple. data is an integer array of shape\n",
    "        (num_samples, input_dim) with values drawn from [0, 10). labels has\n",
    "        shape (num_samples, 1): the first num_samples // 2 rows are 1.0 and\n",
    "        the remaining rows are 0.0.\n",
    "    \"\"\"\n",
    "    # Re-seeding NumPy's global generator makes every call return the same data.\n",
    "    np.random.seed(seed)\n",
    "    data = np.random.randint(10, size=(num_samples, input_dim))\n",
    "    num_positive = num_samples // 2\n",
    "    labels = np.concatenate(\n",
    "        (np.ones((num_positive, 1)), np.zeros((num_samples - num_positive, 1))),\n",
    "        axis=0)\n",
    "    return data, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7PGfZJwq3CIu"
   },
   "source": [
    "# Build Base Model and Tensorized Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-kGUGRr5CJyj"
   },
   "outputs": [],
   "source": [
    "data, labels = dummy_data(1296)\n",
    "# Every model takes its input width from the data rather than a repeated literal.\n",
    "input_shape = (data.shape[1],)\n",
    "\n",
    "# Baseline: a plain fully connected network\n",
    "model = Sequential()\n",
    "model.add(Dense(512, use_bias=True, activation='relu', input_shape=input_shape))\n",
    "model.add(Dense(128, use_bias=True, activation='relu'))\n",
    "model.add(Dense(1, use_bias=True, activation='sigmoid'))\n",
    "\n",
    "# Build the same fully connected network using TN layer DenseDecomp\n",
    "decomp_model = Sequential()\n",
    "decomp_model.add(DenseDecomp(512, decomp_size=64, use_bias=True, activation='relu', input_shape=input_shape))\n",
    "decomp_model.add(DenseDecomp(128, decomp_size=64, use_bias=True, activation='relu'))\n",
    "decomp_model.add(DenseDecomp(1, decomp_size=8, use_bias=True, activation='sigmoid'))\n",
    "\n",
    "# Build a comparable network using TN layer DenseMPO. Its hidden widths differ\n",
    "# from the baseline (256 and 81 instead of 512 and 128) and its output layer is\n",
    "# a plain Dense layer.\n",
    "# NOTE(review): presumably DenseMPO needs widths that are exact num_nodes-th\n",
    "# powers (1296 = 6**4, 256 = 4**4, 81 = 3**4) - confirm against the tn_keras docs.\n",
    "mpo_model = Sequential()\n",
    "mpo_model.add(DenseMPO(256, num_nodes=4, bond_dim=8, use_bias=True, activation='relu', input_shape=input_shape))\n",
    "mpo_model.add(DenseMPO(81, num_nodes=4, bond_dim=4, use_bias=True, activation='relu'))\n",
    "mpo_model.add(Dense(1, use_bias=True, activation='sigmoid'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JnSHekWc25--"
   },
   "source": [
    "# Analyze Parameter Reduction from Tensorization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "id": "Rnx_Ny93_by4",
    "outputId": "21ee8502-10f1-4bf9-a6e9-84f01f0a924d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense (Dense)                (None, 512)               664064    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 128)               65664     \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 1)                 129       \n",
      "=================================================================\n",
      "Total params: 729,857\n",
      "Trainable params: 729,857\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the per-layer output shapes and parameter counts of the dense baseline.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "id": "AkV2xv69_fOh",
    "outputId": "341a383f-9bda-46ee-e6cb-71d2bd0560aa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_decomp (DenseDecomp)   (None, 512)               116224    \n",
      "_________________________________________________________________\n",
      "dense_decomp_1 (DenseDecomp) (None, 128)               41088     \n",
      "_________________________________________________________________\n",
      "dense_decomp_2 (DenseDecomp) (None, 1)                 1033      \n",
      "=================================================================\n",
      "Total params: 158,345\n",
      "Trainable params: 158,345\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the per-layer output shapes and parameter counts of the DenseDecomp model.\n",
    "decomp_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "id": "jP4nX2Qb2aek",
    "outputId": "7ea2ec9a-0af5-4ad5-aed6-683f445e47df"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_2\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_mpo (DenseMPO)         (None, 256)               3712      \n",
      "_________________________________________________________________\n",
      "dense_mpo_1 (DenseMPO)       (None, 81)                561       \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 1)                 82        \n",
      "=================================================================\n",
      "Total params: 4,355\n",
      "Trainable params: 4,355\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the per-layer output shapes and parameter counts of the DenseMPO model.\n",
    "mpo_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "FlLxzu1pEfCC",
    "outputId": "c176fe93-aeb3-4cb8-ed36-cc4b45986bf4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compression factor from tensorization with DenseDecomp: 4.609283526476997\n",
      "Compression factor from tensorization with DenseMPO: 167.5905855338691\n"
     ]
    }
   ],
   "source": [
    "# Compression factor = baseline parameter count / tensorized parameter count.\n",
    "baseline_params = model.count_params()\n",
    "decomp_factor = baseline_params / decomp_model.count_params()\n",
    "mpo_factor = baseline_params / mpo_model.count_params()\n",
    "\n",
    "print(f'Compression factor from tensorization with DenseDecomp: {decomp_factor}')\n",
    "print(f'Compression factor from tensorization with DenseMPO: {mpo_factor}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "r4mawL7Y2xtk"
   },
   "source": [
    "# Train Models for Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 357
    },
    "colab_type": "code",
    "id": "xLWLyQ7bD5fB",
    "outputId": "0b5a09ab-dd7a-4bcb-9dcd-cefe9c470000"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "4/4 [==============================] - 0s 8ms/step - loss: 10.5600 - accuracy: 0.4600\n",
      "Epoch 2/10\n",
      "4/4 [==============================] - 0s 6ms/step - loss: 3.5192 - accuracy: 0.5000\n",
      "Epoch 3/10\n",
      "4/4 [==============================] - 0s 6ms/step - loss: 1.6383 - accuracy: 0.5600\n",
      "Epoch 4/10\n",
      "4/4 [==============================] - 0s 6ms/step - loss: 3.2225 - accuracy: 0.5000\n",
      "Epoch 5/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 1.5394 - accuracy: 0.5000\n",
      "Epoch 6/10\n",
      "4/4 [==============================] - 0s 6ms/step - loss: 0.9699 - accuracy: 0.5600\n",
      "Epoch 7/10\n",
      "4/4 [==============================] - 0s 7ms/step - loss: 0.6281 - accuracy: 0.6700\n",
      "Epoch 8/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.5347 - accuracy: 0.7500\n",
      "Epoch 9/10\n",
      "4/4 [==============================] - 0s 6ms/step - loss: 0.4560 - accuracy: 0.8200\n",
      "Epoch 10/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.4128 - accuracy: 0.9100\n"
     ]
    }
   ],
   "source": [
    "# Adam optimizer with binary cross-entropy loss; accuracy is reported each epoch.\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='binary_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")\n",
    "\n",
    "# Fit the dense baseline: 10 passes over the data in mini-batches of 32.\n",
    "history = model.fit(data, labels, batch_size=32, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 357
    },
    "colab_type": "code",
    "id": "VmrvcNbaEuQK",
    "outputId": "aec77733-12b7-46db-9ccc-a25935a5a6ba"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "4/4 [==============================] - 0s 7ms/step - loss: 2.0713 - accuracy: 0.5700\n",
      "Epoch 2/10\n",
      "4/4 [==============================] - 0s 4ms/step - loss: 1.7343 - accuracy: 0.4500\n",
      "Epoch 3/10\n",
      "4/4 [==============================] - 0s 4ms/step - loss: 1.3718 - accuracy: 0.5100\n",
      "Epoch 4/10\n",
      "4/4 [==============================] - 0s 4ms/step - loss: 1.0777 - accuracy: 0.4900\n",
      "Epoch 5/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 1.3767 - accuracy: 0.5400\n",
      "Epoch 6/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.8864 - accuracy: 0.5500\n",
      "Epoch 7/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.5909 - accuracy: 0.6800\n",
      "Epoch 8/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.5715 - accuracy: 0.6900\n",
      "Epoch 9/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.4912 - accuracy: 0.7200\n",
      "Epoch 10/10\n",
      "4/4 [==============================] - 0s 5ms/step - loss: 0.3498 - accuracy: 0.9000\n"
     ]
    }
   ],
   "source": [
    "# Adam optimizer with binary cross-entropy loss; accuracy is reported each epoch.\n",
    "decomp_model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='binary_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")\n",
    "\n",
    "# Fit the DenseDecomp model: 10 passes over the data in mini-batches of 32.\n",
    "history = decomp_model.fit(data, labels, batch_size=32, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 357
    },
    "colab_type": "code",
    "id": "IMMtsmr5EyIx",
    "outputId": "12b8d799-a6e3-4fe4-c02d-080c77de1425"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "4/4 [==============================] - 0s 10ms/step - loss: 0.6926 - accuracy: 0.5100\n",
      "Epoch 2/10\n",
      "4/4 [==============================] - 0s 8ms/step - loss: 0.6890 - accuracy: 0.5100\n",
      "Epoch 3/10\n",
      "4/4 [==============================] - 0s 8ms/step - loss: 0.6856 - accuracy: 0.5000\n",
      "Epoch 4/10\n",
      "4/4 [==============================] - 0s 7ms/step - loss: 0.6813 - accuracy: 0.5300\n",
      "Epoch 5/10\n",
      "4/4 [==============================] - 0s 9ms/step - loss: 0.6776 - accuracy: 0.7200\n",
      "Epoch 6/10\n",
      "4/4 [==============================] - 0s 8ms/step - loss: 0.6733 - accuracy: 0.8400\n",
      "Epoch 7/10\n",
      "4/4 [==============================] - 0s 9ms/step - loss: 0.6689 - accuracy: 0.8300\n",
      "Epoch 8/10\n",
      "4/4 [==============================] - 0s 8ms/step - loss: 0.6635 - accuracy: 0.8400\n",
      "Epoch 9/10\n",
      "4/4 [==============================] - 0s 9ms/step - loss: 0.6581 - accuracy: 0.8100\n",
      "Epoch 10/10\n",
      "4/4 [==============================] - 0s 8ms/step - loss: 0.6501 - accuracy: 0.9300\n"
     ]
    }
   ],
   "source": [
    "# Adam optimizer with binary cross-entropy loss; accuracy is reported each epoch.\n",
    "mpo_model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='binary_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")\n",
    "\n",
    "# Fit the DenseMPO model: 10 passes over the data in mini-batches of 32.\n",
    "history = mpo_model.fit(data, labels, batch_size=32, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "1xS6vgC_1DYi"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "TN_Keras.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
